# subj01_voxel_vae_tensorboard_resume.py
# export CUDA_VISIBLE_DEVICES=0,1
# =============== 1. Environment & dependencies ===============
import os, sys, json, random, h5py, numpy as np
from tqdm import tqdm
import webdataset as wds
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from scipy.stats import pearsonr

sys.path.append("mindeye2_src")
from mindeye2_src.utils import (
    cosine_anneal, soft_clip_loss, batchwise_cosine_similarity, topk, seed_everything
)
from models import CentralFoveaAttention , VoxelVAE
# seed_everything is imported from mindeye2_src.utils; presumably it seeds the
# Python / NumPy / torch RNGs -- confirm against that helper.
seed_everything(42)

# =============== 2. Hyperparameters ===============
# -- Hardware, data location and output folder --
device = torch.device('cuda:0')
data_path = 'dataset'
subj = 7            # subject number; interpolated into the dataset paths and output_dir
num_session = 40    # the training shards are {0..num_session-1}.tar
output_dir = f"./subj0{subj}_vae_eye"
os.makedirs(output_dir, exist_ok=True)

# -- Batching and epoch length --
batch_size = 32
test_batch_size = 3000
# NOTE(review): 750 looks like "samples per session" -- confirm.
num_samples_per_epoch = 750 * num_session
num_iterations_per_epoch = num_samples_per_epoch // batch_size
num_epochs = 300

# -- Model size and loss weighting --
hidden_dim = 256
n_blocks = 2
kl_beta = 1e-4      # weight of the KL term in the training loss

# -- Shape of one stored CLIP embedding: (clip_seq_dim, clip_emb_dim) --
clip_seq_dim = 257
clip_emb_dim = 768

# =============== 3. Datasets ===============
# Read the whole 'betas' array into memory as a float32 tensor; the HDF5 file
# is closed again as soon as the copy has been made.
betas_path = f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5'
with h5py.File(betas_path, 'r') as betas_file:
    voxels = torch.tensor(betas_file['betas'][:]).float()
num_voxels = voxels.shape[-1]

# The CLIP embeddings stay on disk and are sliced on demand, so this file
# handle has to remain open until the end of the script.
clip_emb_file = h5py.File(f'{data_path}/clip_embeddings.hdf5', 'r')
clip_embeddings = clip_emb_file['embeddings']  # (N, 257, 768)


def my_split_by_node(urls):
    """Return the shard list unchanged (identity node splitter)."""
    return urls


def _behav_pipeline(url, resampled):
    """Build the webdataset pipeline that yields the four behaviour arrays."""
    return (
        wds.WebDataset(url, resampled=resampled, nodesplitter=my_split_by_node)
        .shuffle(750, initial=1500, rng=random.Random(42))
        .decode("torch")
        .rename(behav="behav.npy", past_behav="past_behav.npy",
                future_behav="future_behav.npy", olds_behav="olds_behav.npy")
        .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
    )


# Loader options shared by the training and the test loader.
_loader_kwargs = dict(shuffle=False, drop_last=False, pin_memory=True,
                      num_workers=4, prefetch_factor=2)

# Training shards are opened with resampled=True.
train_url = f"{data_path}/wds/subj0{subj}/train/{{0..{num_session-1}}}.tar"
train_data = _behav_pipeline(train_url, resampled=True)
train_dl = DataLoader(train_data, batch_size=batch_size, **_loader_kwargs)

# The test split is a single shard read without resampling.
test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"
test_data = _behav_pipeline(test_url, resampled=False)
test_dl = DataLoader(test_data, batch_size=test_batch_size, **_loader_kwargs)

# =============== 4. Model & optimizer initialisation ===============
# NOTE(review): 768 / 257 carry the same values as clip_emb_dim / clip_seq_dim
# defined with the hyperparameters -- confirm they are meant to stay in sync.
vae_config = dict(
    num_voxels=num_voxels,
    token_dim=768,
    num_tokens=257,
    hidden_dim=hidden_dim,
    n_blocks=n_blocks,
    drop=0.15,
)
model = VoxelVAE(**vae_config).to(device)
attn = CentralFoveaAttention(embed_dim=768, grid_size=16).to(device)

# A single optimizer updates the parameters of both modules.
trainable_params = [*model.parameters(), *attn.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)

# The scheduler is stepped once per training iteration, so its length is the
# total number of iterations in the run; warm-up covers the first 2 epochs' share.
total_steps = num_epochs * num_iterations_per_epoch
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    total_steps=total_steps,
    pct_start=2 / num_epochs,
    final_div_factor=1000,
    last_epoch=-1,
)

# One contrastive-loss temperature per epoch (indexed by epoch during training).
soft_loss_temps = cosine_anneal(0.004, 0.0075, num_epochs)

# =============== 5. TensorBoard ===============
# Event files are written to <output_dir>/tensorboard.
tb_dir = os.path.join(output_dir, "tensorboard")
os.makedirs(tb_dir, exist_ok=True)
writer = SummaryWriter(log_dir=tb_dir)

# =============== 6. Utility functions ===============
def pearson_r_and_r2(pred, true):
    """Return the row-averaged Pearson r and R^2 between two 2-D tensors.

    Both statistics are computed independently for every row (first axis)
    and then averaged over rows.

    Args:
        pred: tensor of predictions, indexed as [row, feature].
        true: tensor of targets with the same shape as ``pred``.

    Returns:
        Tuple ``(r, r2)``:
            r  -- mean per-row Pearson correlation.
            r2 -- one minus the mean per-row ratio SS_res / SS_tot.
        Rows for which a statistic is undefined (e.g. a constant target row,
        where SS_tot == 0) are left out of the corresponding average, so a
        single degenerate row no longer turns the whole result into inf/nan.
        If every row is undefined the result is nan.
    """
    pred_np = pred.detach().cpu().numpy().astype(np.float64)
    true_np = true.detach().cpu().numpy().astype(np.float64)

    # pearsonr yields nan for constant rows; nanmean skips those rows.
    r = np.nanmean([pearsonr(p_row, t_row)[0]
                    for p_row, t_row in zip(pred_np, true_np)])

    ss_res = ((true_np - pred_np) ** 2).sum(axis=1)
    ss_tot = ((true_np - true_np.mean(axis=1, keepdims=True)) ** 2).sum(axis=1)

    # Divide only where the denominator is non-zero; the remaining rows keep
    # the nan they were initialised with and are ignored by nanmean, which
    # mirrors how undefined rows are treated for r above.
    ratio = np.full(ss_tot.shape, np.nan)
    np.divide(ss_res, ss_tot, out=ratio, where=ss_tot > 0)
    r2 = 1 - np.nanmean(ratio)
    return r, r2

def kl_divergence(mu, logvar):
    """Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).

    The per-element terms are summed over dim 1 and the result is averaged
    over whatever dimensions remain.

    Args:
        mu: tensor of posterior means.
        logvar: tensor of posterior log-variances, same shape as ``mu``.

    Returns:
        0-dim tensor holding the averaged KL value.
    """
    summed = (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    return -0.5 * summed.mean()

# =============== 7. Checkpoint resume logic ===============
def find_latest_checkpoint(out_dir):
    """Locate the checkpoint with the highest epoch number in ``out_dir``.

    Checkpoints are files named ``ckpt_<epoch>.pt``. Files that match the
    prefix/suffix but carry a non-numeric epoch (e.g. ``ckpt_best.pt``) are
    skipped instead of aborting the search with a ValueError.

    Args:
        out_dir: directory to scan.

    Returns:
        Tuple ``(path, epoch)`` for the newest checkpoint, or ``(None, 0)``
        when the directory does not exist or holds no usable checkpoint.
    """
    if not os.path.isdir(out_dir):
        return None, 0

    candidates = []
    for name in os.listdir(out_dir):
        if not (name.startswith("ckpt_") and name.endswith(".pt")):
            continue
        # The epoch is the text between the first "_" and the following ".".
        try:
            epoch = int(name.split("_")[1].split(".")[0])
        except ValueError:
            continue
        candidates.append((epoch, name))

    if not candidates:
        return None, 0

    epoch, latest = max(candidates, key=lambda item: item[0])
    return os.path.join(out_dir, latest), epoch

# Resume from the newest checkpoint in output_dir, if there is one.
# start_epoch is the number of epochs already completed (taken from the
# checkpoint file name) and is where the training loop picks up.
ckpt_path, start_epoch = find_latest_checkpoint(output_dir)
if ckpt_path and os.path.isfile(ckpt_path):
    print(f"🔄  Resuming training from {ckpt_path} (epoch {start_epoch})")
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    attn.load_state_dict(ckpt["attn"])
    optimizer.load_state_dict(ckpt["optimizer"])
    # The scheduler state dict already contains its step counter (last_epoch),
    # which equals the number of scheduler steps taken when the checkpoint was
    # written. Overwriting it with "steps - 1" would rewind the schedule by one
    # step and make the first resumed step reuse the previous learning rate,
    # so the restored value is kept as is.
    lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
else:
    start_epoch = 0
    print("🆕  No checkpoint found, start training from scratch.")

# =============== 8. Training ===============
for epoch in range(start_epoch, num_epochs):
    # Switch both modules to training mode at the start of EVERY epoch. The
    # evaluation pass further down puts them into eval mode; setting train mode
    # only once before the loop would leave dropout disabled for every epoch
    # that follows the first evaluation.
    model.train()
    attn.train()
    losses, mse_losses, kl_losses, contrastive_losses, lrs = [], [], [], [], []

    # ---- 8a. Pre-fetch one epoch worth of batches into CPU buffers ----
    embedding_iters = torch.zeros(
        num_iterations_per_epoch, batch_size, clip_seq_dim, clip_emb_dim,
        dtype=torch.float32)
    voxel_iters = torch.zeros(
        num_iterations_per_epoch, batch_size, num_voxels, dtype=torch.float32)

    n_filled = 0
    for behav0, _, _, _ in train_dl:
        # behav0[:, 0, 0] is used as the row index into clip_embeddings.
        embedding_idx = behav0[:, 0, 0].cpu().long().numpy()
        # np.unique returns the indices sorted plus the position of each one in
        # the batch (presumably sorted because h5py fancy indexing expects
        # increasing indices -- confirm).
        unique_idx, first_pos = np.unique(embedding_idx, return_index=True)
        if len(unique_idx) != len(embedding_idx):
            # The batch contains the same image more than once: skip it.
            continue
        embedding_iters[n_filled] = torch.tensor(clip_embeddings[unique_idx])

        # behav0[:, 0, 5] is used as the row index into voxels; reorder it with
        # first_pos so that voxel rows stay aligned with the sorted embeddings.
        voxel_idx = behav0[:, 0, 5].cpu().long().numpy()
        voxel_iters[n_filled] = voxels[voxel_idx[first_pos]]

        n_filled += 1
        if n_filled >= num_iterations_per_epoch:
            break

    # ---- 8b. Optimisation steps ----
    pbar = tqdm(range(num_iterations_per_epoch), desc=f"Epoch {epoch+1}/{num_epochs}")
    for step in pbar:
        optimizer.zero_grad()
        voxel = voxel_iters[step].to(device)

        # CLIP embedding of the images shown for this batch
        image_rep = embedding_iters[step].to(device)

        # VAE forward pass and attention over the CLIP tokens
        z, recon, mu, logvar = model(voxel)
        queried_image_rep = attn(image_rep)

        # Loss = reconstruction + weighted KL + contrastive alignment
        mse = F.mse_loss(recon, voxel)
        kl = kl_divergence(mu, logvar)
        z_norm = F.normalize(z.flatten(1), dim=-1)
        q_norm = F.normalize(queried_image_rep.flatten(1), dim=-1)
        contr = soft_clip_loss(z_norm, q_norm, temp=soft_loss_temps[epoch])
        loss = mse + kl_beta * kl + contr

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        losses.append(loss.item())
        mse_losses.append(mse.item())
        kl_losses.append(kl.item())
        contrastive_losses.append(contr.item())
        lrs.append(optimizer.param_groups[0]["lr"])
        pbar.set_postfix(loss=loss.item(), mse=mse.item(), kl=kl.item(), contr=contr.item(), lr=lrs[-1])

        # TensorBoard receives the running mean over the current epoch.
        global_step = epoch * num_iterations_per_epoch + step
        writer.add_scalar("train/loss",  np.mean(losses), global_step)
        writer.add_scalar("train/mse",   np.mean(mse_losses), global_step)
        writer.add_scalar("train/kl",    np.mean(kl_losses), global_step)
        writer.add_scalar("train/contr", np.mean(contrastive_losses), global_step)
        writer.add_scalar("train/lr",    lrs[-1], global_step)

    print(f"E{epoch+1:03d} "
      f"train_loss={np.mean(losses):.4f} "
      f"train_mse={np.mean(mse_losses):.4f} "
      f"train_kl={np.mean(kl_losses):.4f} "
      f"train_contr={np.mean(contrastive_losses):.4f} "
      f"lr={lrs[-1]:.6f}")

    # =============== 9. Evaluation (every 10 epochs) ===============
    if (epoch + 1) % 10 == 0:
        test_clip_emb = None
        test_voxel = None
        # Both modules go to eval mode so that the metrics are computed
        # without dropout; train mode is restored at the top of the next epoch.
        model.eval()
        attn.eval()
        with torch.no_grad():
            # Only the first batch of the test loader is used.
            for behav, _, _, _ in test_dl:
                voxel = voxels[behav[:, 0, 5].cpu().long()]
                image_idx = behav[:, 0, 0].cpu().long()

                emb_rows, voxel_rows = [], []
                for im in torch.unique(image_idx):
                    # Each image needs exactly three voxel rows; images with
                    # one or two presentations are padded by repetition.
                    locs = torch.where(im == image_idx)[0]
                    if len(locs) == 1:
                        locs = locs.repeat(3)
                    elif len(locs) == 2:
                        locs = locs.repeat(2)[:3]
                    assert len(locs) == 3
                    emb_rows.append(torch.Tensor(clip_embeddings[int(im)]))
                    voxel_rows.append(voxel[locs])

                # Stack once instead of growing the tensors row by row.
                test_clip_emb = torch.stack(emb_rows)
                test_voxel = torch.stack(voxel_rows)
                break

            if test_clip_emb is None:
                raise RuntimeError(f"Test loader yielded no data: {test_url}")

            # Average the three presentations of every image.
            test_voxel_mean = torch.mean(test_voxel, dim=1)

            # Evaluate on a random subset of at most 300 images.
            n_eval = min(300, len(test_voxel_mean))
            random_samps = np.random.choice(np.arange(len(test_voxel_mean)), size=n_eval, replace=False)
            image_rep = test_clip_emb[random_samps].to(device)
            test_voxel_mean = test_voxel_mean[random_samps].to(device)

            z, pred_voxel, mu, logvar = model(test_voxel_mean)
            queried_image_rep = attn(image_rep)

            z_norm = F.normalize(z.flatten(1), dim=-1)
            q_norm = F.normalize(queried_image_rep.flatten(1), dim=-1)

            # Top-1 retrieval in both directions; item i matches label i.
            labels = torch.arange(z.size(0), device=device)
            fwd = topk(batchwise_cosine_similarity(z_norm, q_norm), labels, k=1).item()
            bwd = topk(batchwise_cosine_similarity(q_norm, z_norm), labels, k=1).item()
            test_mse = F.mse_loss(pred_voxel, test_voxel_mean).item()
            test_kl = kl_divergence(mu, logvar).item()
            r, r2 = pearson_r_and_r2(pred_voxel, test_voxel_mean)

            print(f"E{epoch+1:03d} "
                  f"loss={np.mean(losses):.4f} "
                  f"mse={test_mse:.4f} "
                  f"kl={test_kl:.4f} "
                  f"fwd={fwd:.3f} "
                  f"bwd={bwd:.3f} "
                  f"pearson_r={r:.4f} "
                  f"r2={r2:.4f}")

            writer.add_scalar("test/fwd",      fwd,  epoch+1)
            writer.add_scalar("test/bwd",      bwd,  epoch+1)
            writer.add_scalar("test/mse",      test_mse, epoch+1)
            writer.add_scalar("test/kl",       test_kl,  epoch+1)
            writer.add_scalar("test/pearson_r",r,    epoch+1)
            writer.add_scalar("test/r2",       r2,   epoch+1)

        # =============== 10. Save checkpoint ===============
        # Written only on evaluation epochs; the epoch number lives in the
        # file name, which is what the resume logic reads back.
        torch.save({
            'model': model.state_dict(),
            'attn': attn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
        }, f"{output_dir}/ckpt_{epoch+1:03d}.pt")

# =============== 11. Teardown ===============
# Close the TensorBoard writer, then release the HDF5 handle that backs
# clip_embeddings.
writer.close()
clip_emb_file.close()